{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__This notebook__ fine-tunes a pre-trained resnet18 model with editable training.\n",
    "\n",
    "__Prepare data:__\n",
    "* Download imagenet training and dataset\n",
    "* Make sure folder names are called \"000\", \"001\", ... \"010\", \"011\", ... and not \"0\", \"1\", ..., \"10\", \"11\", ...\n",
    "    * rename if necessary\n",
    "* Run `imagenet_preprocess_logits.ipynb` to prepare fine-tuning metadata.\n",
    "\n",
    "__Training:__\n",
    "* Set environment variables and paths in the next cell\n",
    "* Run all cells :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=1\n",
      "imagenet_editable_extra_layer_2019.09.19_23:11:24\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE\n",
    "\n",
    "traindir = '../../imagenet/train'  # path to train ImageFolder\n",
    "valdir = '../../imagenet/val'      # path to validation ImageFolder\n",
    "logits_path = './imagenet_logits/' # see imagenet_preprocess_logits\n",
    "\n",
    "import os, sys, time\n",
    "sys.path.insert(0, '..')\n",
    "import lib\n",
    "\n",
    "import numpy as np\n",
    "import torch, torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import random\n",
    "random.seed(42)\n",
    "np.random.seed(42)\n",
    "torch.random.manual_seed(42)\n",
    "\n",
    "import time\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "experiment_name = 'imagenet_editable_extra_layer'\n",
    "experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(experiment_name, *time.gmtime()[:6])\n",
    "print(experiment_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "\n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                 std=[0.229, 0.224, 0.225])\n",
    "\n",
    "train_dataset = lib.ImageAndLogitsFolder(\n",
    "    traindir,\n",
    "    transforms.Compose([\n",
    "        transforms.RandomResizedCrop(224),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        normalize,\n",
    "    ]),\n",
    "    logits_prefix = logits_path\n",
    ")\n",
    "\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True,\n",
    "    num_workers=12, pin_memory=True)\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    datasets.ImageFolder(valdir, transforms.Compose([\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        normalize,\n",
    "    ])),\n",
    "    batch_size=batch_size, shuffle=False,\n",
    "    num_workers=32, pin_memory=True)\n",
    "\n",
    "X_test, y_test = map(torch.cat, zip(*val_loader))\n",
    "X_test, y_test = X_test[::10], y_test[::10]\n",
    "# !!!IMPORTANT!!!\n",
    "# We use 10% of validation samples for faster validation, please use full validation set to measure \"final\" error rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "model = torchvision.models.resnet18(pretrained=True)\n",
    "\n",
    "optimizer = lib.IngraphRMSProp(learning_rate=1e-4, beta=nn.Parameter(torch.as_tensor(0.5)))\n",
    "\n",
    "model = lib.SequentialWithEditable(\n",
    "    model.conv1, model.bn1, model.relu, model.maxpool,\n",
    "    model.layer1, model.layer2, model.layer3, model.layer4,\n",
    "    model.avgpool, lib.Flatten(),\n",
    "    lib.Editable(\n",
    "        lib.Residual(nn.Linear(512, 4096), nn.ELU(), nn.Linear(4096, 512)),\n",
    "        loss_function=lib.contrastive_cross_entropy, \n",
    "        optimizer=optimizer, max_steps=10),\n",
    "\n",
    "    model.fc\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classification_error(model, X_test, y_test):\n",
    "    with lib.training_mode(model, is_train=False):\n",
    "        return lib.classification_error(lib.Lambda(lambda x: model(x.to(device))),\n",
    "                                        X_test, y_test, device='cpu', batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_params = set(model.editable.module[0].parameters())\n",
    "old_params = [param for param in model.parameters() if param not in new_params]\n",
    "\n",
    "training_opt = lib.OptimizerList(\n",
    "    torch.optim.SGD(old_params, lr=1e-5, momentum=0.9, weight_decay=1e-4),\n",
    "    torch.optim.SGD(new_params, lr=1e-3, momentum=0.9, weight_decay=1e-4),\n",
    ")\n",
    "\n",
    "trainer = lib.DistillationEditableTrainer(model,\n",
    "          stability_coeff=0.03, editability_coeff=0.03,\n",
    "          experiment_name=experiment_name,\n",
    "          error_function=classification_error,\n",
    "          opt=training_opt, max_norm=10)\n",
    "\n",
    "trainer.writer.add_text(\"trainer\", repr(trainer).replace('\\n', '<br>'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm_notebook, tnrange\n",
    "from IPython.display import clear_output\n",
    "\n",
    "# Learnign params\n",
    "eval_batch_cd = 500\n",
    "val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n",
    "min_error, min_drawdown = val_metrics['base_error'], val_metrics['drawdown']\n",
    "early_stopping_epochs = 500\n",
    "number_of_epochs_without_improvement = 0\n",
    "            \n",
    "def edit_generator():\n",
    "    while True:\n",
    "        for xb, yb, lg in torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2):\n",
    "            yield xb.to(device), torch.randint_like(yb, low=0, high=max(y_test) + 1, device=device)\n",
    "\n",
    "edit_generator = edit_generator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "while True:\n",
    "    \n",
    "    for x_batch, y_batch, logits in tqdm_notebook(train_loader):\n",
    "        trainer.step(x_batch.to(device), logits.to(device), *next(edit_generator))\n",
    "        \n",
    "        if trainer.total_steps % eval_batch_cd == 0:\n",
    "            val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n",
    "            clear_output(True)\n",
    "\n",
    "            error_rate, drawdown = val_metrics['base_error'], val_metrics['drawdown']\n",
    "\n",
    "            number_of_epochs_without_improvement += 1\n",
    "\n",
    "            if error_rate < min_error:\n",
    "                trainer.save_checkpoint(tag='best_val_error')\n",
    "                min_error = error_rate\n",
    "                number_of_epochs_without_improvement = 0\n",
    "\n",
    "            if drawdown < min_drawdown:\n",
    "                trainer.save_checkpoint(tag='best_drawdown')\n",
    "                min_drawdown = drawdown\n",
    "                number_of_epochs_without_improvement = 0\n",
    "\n",
    "            trainer.save_checkpoint()\n",
    "            trainer.remove_old_temp_checkpoints()\n",
    "\n",
    "            if number_of_epochs_without_improvement > early_stopping_epochs:\n",
    "                break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate drawdown\n",
    "\n",
    "__Note:__ this code evaluates quality on 10% of the validation set. In paper we use this subset when evaluating drawdown but we measure the base error on all 50k validation samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# edit quality\n",
    "\n",
    "from lib import evaluate_quality\n",
    "\n",
    "np.random.seed(9)\n",
    "indices = np.random.permutation(len(X_test))[:1000]\n",
    "X_edit = X_test[indices].clone().to(device)\n",
    "y_edit = torch.tensor(np.random.randint(0, max(y_test) + 1, size=y_test[indices].shape), device=device)\n",
    "metrics = evaluate_quality(model, X_test, y_test, X_edit, y_edit, \n",
    "                           error_function=classification_error, progressbar=tqdm_notebook)\n",
    "\n",
    "for key in sorted(metrics.keys()):\n",
    "    print('{}\\t:{:.5}'.format(key, metrics[key]))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
